{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n",
    "- Author: Sebastian Raschka\n",
    "- GitHub Repository: https://github.com/rasbt/deeplearning-models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sebastian Raschka \n",
      "\n",
      "CPython 3.6.8\n",
      "IPython 7.2.0\n",
      "\n",
      "torch 1.0.1.post2\n"
     ]
    }
   ],
   "source": [
    "%load_ext watermark\n",
    "%watermark -a 'Sebastian Raschka' -v -p torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Runs on CPU or GPU (if available)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model Zoo -- Convolutional ResNet and Residual Blocks"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Please note that this example does not implement a really deep ResNet as described in the literature; rather, it illustrates how the residual blocks described in He et al. [1] can be implemented in PyTorch.\n",
    "\n",
    "- [1] He, Kaiming, et al. \"Deep residual learning for image recognition.\" *Proceedings of the IEEE conference on computer vision and pattern recognition*. 2016."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import datasets\n",
    "from torchvision import transforms\n",
    "\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    torch.backends.cudnn.deterministic = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Settings and Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Image batch dimensions: torch.Size([128, 1, 28, 28])\n",
      "Image label dimensions: torch.Size([128])\n"
     ]
    }
   ],
   "source": [
    "##########################\n",
    "### SETTINGS\n",
    "##########################\n",
    "\n",
    "# Device (first GPU if available, else CPU; change the index on multi-GPU machines)\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Hyperparameters\n",
    "random_seed = 123\n",
    "learning_rate = 0.01\n",
    "num_epochs = 10\n",
    "batch_size = 128\n",
    "\n",
    "# Architecture\n",
    "num_classes = 10\n",
    "\n",
    "\n",
    "##########################\n",
    "### MNIST DATASET\n",
    "##########################\n",
    "\n",
    "# Note transforms.ToTensor() scales input images\n",
    "# to 0-1 range\n",
    "train_dataset = datasets.MNIST(root='data', \n",
    "                               train=True, \n",
    "                               transform=transforms.ToTensor(),\n",
    "                               download=True)\n",
    "\n",
    "test_dataset = datasets.MNIST(root='data', \n",
    "                              train=False, \n",
    "                              transform=transforms.ToTensor())\n",
    "\n",
    "\n",
    "train_loader = DataLoader(dataset=train_dataset, \n",
    "                          batch_size=batch_size, \n",
    "                          shuffle=True)\n",
    "\n",
    "test_loader = DataLoader(dataset=test_dataset, \n",
    "                         batch_size=batch_size, \n",
    "                         shuffle=False)\n",
    "\n",
    "# Checking the dataset\n",
    "for images, labels in train_loader:  \n",
    "    print('Image batch dimensions:', images.shape)\n",
    "    print('Image label dimensions:', labels.shape)\n",
    "    break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ResNet with identity blocks"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following code implements residual blocks whose skip connections pass the input through unchanged (identity shortcuts). This requires the main path's output to have the same dimensions as the block's input, and it makes it easy for a block to learn the identity function. Such a residual block is illustrated below:\n",
    "\n",
    "![](../images/resnets/resnet-ex-1-1.png)"
   ]
  },
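  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "An identity shortcut only works if the main path returns a tensor with exactly the same shape as its input; otherwise the addition `out += shortcut` fails. The following minimal sketch mirrors the 1x1-then-3x3 layout used in the model below and checks that the shapes line up. The class name `IdentityBlock` and its `channels`/`hidden` arguments are hypothetical and not part of this notebook:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "\n",
    "class IdentityBlock(torch.nn.Module):\n",
    "    # Hypothetical block: 1x1 conv (expand) -> 3x3 conv (project back)\n",
    "    def __init__(self, channels, hidden):\n",
    "        super().__init__()\n",
    "        self.conv_a = torch.nn.Conv2d(channels, hidden, kernel_size=1)\n",
    "        self.bn_a = torch.nn.BatchNorm2d(hidden)\n",
    "        self.conv_b = torch.nn.Conv2d(hidden, channels, kernel_size=3, padding=1)\n",
    "        self.bn_b = torch.nn.BatchNorm2d(channels)\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = F.relu(self.bn_a(self.conv_a(x)))\n",
    "        out = self.bn_b(self.conv_b(out))\n",
    "        # Identity shortcut: valid only because out.shape == x.shape\n",
    "        return F.relu(out + x)\n",
    "\n",
    "\n",
    "block = IdentityBlock(channels=1, hidden=4)\n",
    "x = torch.randn(8, 1, 28, 28)\n",
    "y = block(x)\n",
    "print(y.shape)  # torch.Size([8, 1, 28, 28])\n",
    "```"
   ]
  },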
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "##########################\n",
    "### MODEL\n",
    "##########################\n",
    "\n",
    "\n",
    "\n",
    "class ConvNet(torch.nn.Module):\n",
    "\n",
    "    def __init__(self, num_classes):\n",
    "        super(ConvNet, self).__init__()\n",
    "        \n",
    "        #########################\n",
    "        ### 1st residual block\n",
    "        #########################\n",
    "        # 28x28x1 => 28x28x4\n",
    "        self.conv_1 = torch.nn.Conv2d(in_channels=1,\n",
    "                                      out_channels=4,\n",
    "                                      kernel_size=(1, 1),\n",
    "                                      stride=(1, 1),\n",
    "                                      padding=0)\n",
    "        self.conv_1_bn = torch.nn.BatchNorm2d(4)\n",
    "                                    \n",
    "        # 28x28x4 => 28x28x1\n",
    "        self.conv_2 = torch.nn.Conv2d(in_channels=4,\n",
    "                                      out_channels=1,\n",
    "                                      kernel_size=(3, 3),\n",
    "                                      stride=(1, 1),\n",
    "                                      padding=1)   \n",
    "        self.conv_2_bn = torch.nn.BatchNorm2d(1)\n",
    "        \n",
    "        \n",
    "        #########################\n",
    "        ### 2nd residual block\n",
    "        #########################\n",
    "        # 28x28x1 => 28x28x4\n",
    "        self.conv_3 = torch.nn.Conv2d(in_channels=1,\n",
    "                                      out_channels=4,\n",
    "                                      kernel_size=(1, 1),\n",
    "                                      stride=(1, 1),\n",
    "                                      padding=0)\n",
    "        self.conv_3_bn = torch.nn.BatchNorm2d(4)\n",
    "                                    \n",
    "        # 28x28x4 => 28x28x1\n",
    "        self.conv_4 = torch.nn.Conv2d(in_channels=4,\n",
    "                                      out_channels=1,\n",
    "                                      kernel_size=(3, 3),\n",
    "                                      stride=(1, 1),\n",
    "                                      padding=1)   \n",
    "        self.conv_4_bn = torch.nn.BatchNorm2d(1)\n",
    "\n",
    "        #########################\n",
    "        ### Fully connected\n",
    "        #########################        \n",
    "        self.linear_1 = torch.nn.Linear(28*28*1, num_classes)\n",
    "\n",
    "        \n",
    "    def forward(self, x):\n",
    "        \n",
    "        #########################\n",
    "        ### 1st residual block\n",
    "        #########################\n",
    "        shortcut = x\n",
    "        \n",
    "        out = self.conv_1(x)\n",
    "        out = self.conv_1_bn(out)\n",
    "        out = F.relu(out)\n",
    "\n",
    "        out = self.conv_2(out)\n",
    "        out = self.conv_2_bn(out)\n",
    "        \n",
    "        out += shortcut\n",
    "        out = F.relu(out)\n",
    "        \n",
    "        #########################\n",
    "        ### 2nd residual block\n",
    "        #########################\n",
    "        \n",
    "        shortcut = out\n",
    "        \n",
    "        out = self.conv_3(out)\n",
    "        out = self.conv_3_bn(out)\n",
    "        out = F.relu(out)\n",
    "\n",
    "        out = self.conv_4(out)\n",
    "        out = self.conv_4_bn(out)\n",
    "        \n",
    "        out += shortcut\n",
    "        out = F.relu(out)\n",
    "        \n",
    "        #########################\n",
    "        ### Fully connected\n",
    "        #########################   \n",
    "        logits = self.linear_1(out.view(-1, 28*28*1))\n",
    "        probas = F.softmax(logits, dim=1)\n",
    "        return logits, probas\n",
    "\n",
    "    \n",
    "torch.manual_seed(random_seed)\n",
    "model = ConvNet(num_classes=num_classes)\n",
    "model = model.to(device)\n",
    "    \n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 001/010 | Batch 000/469 | Cost: 2.5157\n",
      "Epoch: 001/010 | Batch 050/469 | Cost: 0.5106\n",
      "Epoch: 001/010 | Batch 100/469 | Cost: 0.2353\n",
      "Epoch: 001/010 | Batch 150/469 | Cost: 0.2672\n",
      "Epoch: 001/010 | Batch 200/469 | Cost: 0.3670\n",
      "Epoch: 001/010 | Batch 250/469 | Cost: 0.2920\n",
      "Epoch: 001/010 | Batch 300/469 | Cost: 0.3122\n",
      "Epoch: 001/010 | Batch 350/469 | Cost: 0.2697\n",
      "Epoch: 001/010 | Batch 400/469 | Cost: 0.4273\n",
      "Epoch: 001/010 | Batch 450/469 | Cost: 0.3696\n",
      "Epoch: 001/010 training accuracy: 92.21%\n",
      "Time elapsed: 0.25 min\n",
      "Epoch: 002/010 | Batch 000/469 | Cost: 0.2612\n",
      "Epoch: 002/010 | Batch 050/469 | Cost: 0.4460\n",
      "Epoch: 002/010 | Batch 100/469 | Cost: 0.2881\n",
      "Epoch: 002/010 | Batch 150/469 | Cost: 0.4010\n",
      "Epoch: 002/010 | Batch 200/469 | Cost: 0.2376\n",
      "Epoch: 002/010 | Batch 250/469 | Cost: 0.2598\n",
      "Epoch: 002/010 | Batch 300/469 | Cost: 0.1649\n",
      "Epoch: 002/010 | Batch 350/469 | Cost: 0.2331\n",
      "Epoch: 002/010 | Batch 400/469 | Cost: 0.2897\n",
      "Epoch: 002/010 | Batch 450/469 | Cost: 0.4034\n",
      "Epoch: 002/010 training accuracy: 92.73%\n",
      "Time elapsed: 0.51 min\n",
      "Epoch: 003/010 | Batch 000/469 | Cost: 0.2406\n",
      "Epoch: 003/010 | Batch 050/469 | Cost: 0.3472\n",
      "Epoch: 003/010 | Batch 100/469 | Cost: 0.2030\n",
      "Epoch: 003/010 | Batch 150/469 | Cost: 0.2327\n",
      "Epoch: 003/010 | Batch 200/469 | Cost: 0.2796\n",
      "Epoch: 003/010 | Batch 250/469 | Cost: 0.2485\n",
      "Epoch: 003/010 | Batch 300/469 | Cost: 0.1806\n",
      "Epoch: 003/010 | Batch 350/469 | Cost: 0.2239\n",
      "Epoch: 003/010 | Batch 400/469 | Cost: 0.4661\n",
      "Epoch: 003/010 | Batch 450/469 | Cost: 0.2216\n",
      "Epoch: 003/010 training accuracy: 93.16%\n",
      "Time elapsed: 0.76 min\n",
      "Epoch: 004/010 | Batch 000/469 | Cost: 0.4196\n",
      "Epoch: 004/010 | Batch 050/469 | Cost: 0.2219\n",
      "Epoch: 004/010 | Batch 100/469 | Cost: 0.1649\n",
      "Epoch: 004/010 | Batch 150/469 | Cost: 0.2900\n",
      "Epoch: 004/010 | Batch 200/469 | Cost: 0.2729\n",
      "Epoch: 004/010 | Batch 250/469 | Cost: 0.2085\n",
      "Epoch: 004/010 | Batch 300/469 | Cost: 0.3587\n",
      "Epoch: 004/010 | Batch 350/469 | Cost: 0.2085\n",
      "Epoch: 004/010 | Batch 400/469 | Cost: 0.2656\n",
      "Epoch: 004/010 | Batch 450/469 | Cost: 0.1630\n",
      "Epoch: 004/010 training accuracy: 93.64%\n",
      "Time elapsed: 1.01 min\n",
      "Epoch: 005/010 | Batch 000/469 | Cost: 0.2607\n",
      "Epoch: 005/010 | Batch 050/469 | Cost: 0.2885\n",
      "Epoch: 005/010 | Batch 100/469 | Cost: 0.4115\n",
      "Epoch: 005/010 | Batch 150/469 | Cost: 0.1415\n",
      "Epoch: 005/010 | Batch 200/469 | Cost: 0.1815\n",
      "Epoch: 005/010 | Batch 250/469 | Cost: 0.2137\n",
      "Epoch: 005/010 | Batch 300/469 | Cost: 0.0949\n",
      "Epoch: 005/010 | Batch 350/469 | Cost: 0.2109\n",
      "Epoch: 005/010 | Batch 400/469 | Cost: 0.2047\n",
      "Epoch: 005/010 | Batch 450/469 | Cost: 0.3176\n",
      "Epoch: 005/010 training accuracy: 93.86%\n",
      "Time elapsed: 1.26 min\n",
      "Epoch: 006/010 | Batch 000/469 | Cost: 0.2820\n",
      "Epoch: 006/010 | Batch 050/469 | Cost: 0.1209\n",
      "Epoch: 006/010 | Batch 100/469 | Cost: 0.2926\n",
      "Epoch: 006/010 | Batch 150/469 | Cost: 0.2950\n",
      "Epoch: 006/010 | Batch 200/469 | Cost: 0.1879\n",
      "Epoch: 006/010 | Batch 250/469 | Cost: 0.2352\n",
      "Epoch: 006/010 | Batch 300/469 | Cost: 0.2423\n",
      "Epoch: 006/010 | Batch 350/469 | Cost: 0.1898\n",
      "Epoch: 006/010 | Batch 400/469 | Cost: 0.3582\n",
      "Epoch: 006/010 | Batch 450/469 | Cost: 0.2269\n",
      "Epoch: 006/010 training accuracy: 93.86%\n",
      "Time elapsed: 1.51 min\n",
      "Epoch: 007/010 | Batch 000/469 | Cost: 0.2327\n",
      "Epoch: 007/010 | Batch 050/469 | Cost: 0.1684\n",
      "Epoch: 007/010 | Batch 100/469 | Cost: 0.1441\n",
      "Epoch: 007/010 | Batch 150/469 | Cost: 0.1740\n",
      "Epoch: 007/010 | Batch 200/469 | Cost: 0.1402\n",
      "Epoch: 007/010 | Batch 250/469 | Cost: 0.2488\n",
      "Epoch: 007/010 | Batch 300/469 | Cost: 0.2436\n",
      "Epoch: 007/010 | Batch 350/469 | Cost: 0.2196\n",
      "Epoch: 007/010 | Batch 400/469 | Cost: 0.1210\n",
      "Epoch: 007/010 | Batch 450/469 | Cost: 0.1820\n",
      "Epoch: 007/010 training accuracy: 94.19%\n",
      "Time elapsed: 1.76 min\n",
      "Epoch: 008/010 | Batch 000/469 | Cost: 0.1494\n",
      "Epoch: 008/010 | Batch 050/469 | Cost: 0.1392\n",
      "Epoch: 008/010 | Batch 100/469 | Cost: 0.2526\n",
      "Epoch: 008/010 | Batch 150/469 | Cost: 0.1961\n",
      "Epoch: 008/010 | Batch 200/469 | Cost: 0.2890\n",
      "Epoch: 008/010 | Batch 250/469 | Cost: 0.2019\n",
      "Epoch: 008/010 | Batch 300/469 | Cost: 0.3335\n",
      "Epoch: 008/010 | Batch 350/469 | Cost: 0.2250\n",
      "Epoch: 008/010 | Batch 400/469 | Cost: 0.1983\n",
      "Epoch: 008/010 | Batch 450/469 | Cost: 0.2136\n",
      "Epoch: 008/010 training accuracy: 94.40%\n",
      "Time elapsed: 2.01 min\n",
      "Epoch: 009/010 | Batch 000/469 | Cost: 0.3670\n",
      "Epoch: 009/010 | Batch 050/469 | Cost: 0.1793\n",
      "Epoch: 009/010 | Batch 100/469 | Cost: 0.3003\n",
      "Epoch: 009/010 | Batch 150/469 | Cost: 0.1713\n",
      "Epoch: 009/010 | Batch 200/469 | Cost: 0.2957\n",
      "Epoch: 009/010 | Batch 250/469 | Cost: 0.2260\n",
      "Epoch: 009/010 | Batch 300/469 | Cost: 0.1860\n",
      "Epoch: 009/010 | Batch 350/469 | Cost: 0.2632\n",
      "Epoch: 009/010 | Batch 400/469 | Cost: 0.2249\n",
      "Epoch: 009/010 | Batch 450/469 | Cost: 0.2512\n",
      "Epoch: 009/010 training accuracy: 94.61%\n",
      "Time elapsed: 2.26 min\n",
      "Epoch: 010/010 | Batch 000/469 | Cost: 0.1599\n",
      "Epoch: 010/010 | Batch 050/469 | Cost: 0.2204\n",
      "Epoch: 010/010 | Batch 100/469 | Cost: 0.1528\n",
      "Epoch: 010/010 | Batch 150/469 | Cost: 0.1847\n",
      "Epoch: 010/010 | Batch 200/469 | Cost: 0.1767\n",
      "Epoch: 010/010 | Batch 250/469 | Cost: 0.1473\n",
      "Epoch: 010/010 | Batch 300/469 | Cost: 0.1407\n",
      "Epoch: 010/010 | Batch 350/469 | Cost: 0.1406\n",
      "Epoch: 010/010 | Batch 400/469 | Cost: 0.3001\n",
      "Epoch: 010/010 | Batch 450/469 | Cost: 0.2306\n",
      "Epoch: 010/010 training accuracy: 93.22%\n",
      "Time elapsed: 2.51 min\n",
      "Total Training Time: 2.51 min\n"
     ]
    }
   ],
   "source": [
    "def compute_accuracy(model, data_loader):\n",
    "    correct_pred, num_examples = 0, 0\n",
    "    for i, (features, targets) in enumerate(data_loader):            \n",
    "        features = features.to(device)\n",
    "        targets = targets.to(device)\n",
    "        logits, probas = model(features)\n",
    "        _, predicted_labels = torch.max(probas, 1)\n",
    "        num_examples += targets.size(0)\n",
    "        correct_pred += (predicted_labels == targets).sum()\n",
    "    return correct_pred.float()/num_examples * 100\n",
    "\n",
    "\n",
    "start_time = time.time()\n",
    "for epoch in range(num_epochs):\n",
    "    model = model.train()\n",
    "    for batch_idx, (features, targets) in enumerate(train_loader):\n",
    "        \n",
    "        features = features.to(device)\n",
    "        targets = targets.to(device)\n",
    "        \n",
    "        ### FORWARD AND BACK PROP\n",
    "        logits, probas = model(features)\n",
    "        cost = F.cross_entropy(logits, targets)\n",
    "        optimizer.zero_grad()\n",
    "        \n",
    "        cost.backward()\n",
    "        \n",
    "        ### UPDATE MODEL PARAMETERS\n",
    "        optimizer.step()\n",
    "        \n",
    "        ### LOGGING\n",
    "        if not batch_idx % 50:\n",
    "            print ('Epoch: %03d/%03d | Batch %03d/%03d | Cost: %.4f' \n",
    "                   %(epoch+1, num_epochs, batch_idx, \n",
    "                     len(train_loader), cost))\n",
    "\n",
    "    model = model.eval() # eval mode to prevent updating the BatchNorm running statistics during inference\n",
    "    with torch.set_grad_enabled(False): # save memory during inference\n",
    "        print('Epoch: %03d/%03d training accuracy: %.2f%%' % (\n",
    "              epoch+1, num_epochs, \n",
    "              compute_accuracy(model, train_loader)))\n",
    "\n",
    "    print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))\n",
    "    \n",
    "print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))"
   ]
  },
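  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `model.eval()` call in the loop above matters because of BatchNorm: in training mode each forward pass updates the layer's `running_mean` and `running_var` buffers, whereas in evaluation mode these buffers are frozen and used for normalization instead of the batch statistics. A minimal standalone sketch (the layer size and input shape are arbitrary choices) illustrates this:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "bn = torch.nn.BatchNorm2d(4)\n",
    "x = torch.randn(16, 4, 8, 8) + 3.0  # inputs with a clearly non-zero mean\n",
    "\n",
    "# Training mode: the forward pass updates the running statistics\n",
    "bn.train()\n",
    "before = bn.running_mean.clone()\n",
    "bn(x)\n",
    "after_train = bn.running_mean.clone()\n",
    "\n",
    "# Evaluation mode: the running statistics stay unchanged\n",
    "bn.eval()\n",
    "with torch.no_grad():\n",
    "    bn(x)\n",
    "after_eval = bn.running_mean.clone()\n",
    "\n",
    "print(torch.equal(before, after_train))      # False\n",
    "print(torch.equal(after_train, after_eval))  # True\n",
    "```"
   ]
  },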
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test accuracy: 92.20%\n"
     ]
    }
   ],
   "source": [
    "with torch.set_grad_enabled(False): # save memory during inference\n",
    "    print('Test accuracy: %.2f%%' % (compute_accuracy(model, test_loader)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ResNet with convolutional blocks for resizing"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following code implements residual blocks whose skip connections resize the input passed via the shortcut (using a strided 1x1 convolution) so that it matches the dimensions of the main path's output. Such a residual block is illustrated below:\n",
    "\n",
    "![](../images/resnets/resnet-ex-1-2.png)"
   ]
  },
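  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The spatial sizes noted in the comments of the next cell follow the usual convolution output-size formula for dilation 1, `out = floor((in + 2*padding - kernel) / stride) + 1`. As a quick sketch, the hypothetical helper `conv_out` below confirms that the 3x3 stride-2 main path and the 1x1 stride-2 shortcut yield the same spatial size in both blocks, which is what makes the addition possible:\n",
    "\n",
    "```python\n",
    "def conv_out(size, kernel, stride, padding):\n",
    "    # Output size of a 2D convolution along one spatial dimension (dilation 1)\n",
    "    return (size + 2 * padding - kernel) // stride + 1\n",
    "\n",
    "\n",
    "# 1st block: main path (3x3, stride 2, pad 1) vs. shortcut (1x1, stride 2, pad 0)\n",
    "main_1 = conv_out(28, kernel=3, stride=2, padding=1)\n",
    "short_1 = conv_out(28, kernel=1, stride=2, padding=0)\n",
    "\n",
    "# 2nd block: the same pattern, starting from 14x14\n",
    "main_2 = conv_out(14, kernel=3, stride=2, padding=1)\n",
    "short_2 = conv_out(14, kernel=1, stride=2, padding=0)\n",
    "\n",
    "print(main_1, short_1, main_2, short_2)  # 14 14 7 7\n",
    "```"
   ]
  },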
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "##########################\n",
    "### MODEL\n",
    "##########################\n",
    "\n",
    "\n",
    "\n",
    "class ConvNet(torch.nn.Module):\n",
    "\n",
    "    def __init__(self, num_classes):\n",
    "        super(ConvNet, self).__init__()\n",
    "        \n",
    "        #########################\n",
    "        ### 1st residual block\n",
    "        #########################\n",
    "        # 28x28x1 => 14x14x4 \n",
    "        self.conv_1 = torch.nn.Conv2d(in_channels=1,\n",
    "                                      out_channels=4,\n",
    "                                      kernel_size=(3, 3),\n",
    "                                      stride=(2, 2),\n",
    "                                      padding=1)\n",
    "        self.conv_1_bn = torch.nn.BatchNorm2d(4)\n",
    "                                    \n",
    "        # 14x14x4 => 14x14x8\n",
    "        self.conv_2 = torch.nn.Conv2d(in_channels=4,\n",
    "                                      out_channels=8,\n",
    "                                      kernel_size=(1, 1),\n",
    "                                      stride=(1, 1),\n",
    "                                      padding=0)   \n",
    "        self.conv_2_bn = torch.nn.BatchNorm2d(8)\n",
    "        \n",
    "        # 28x28x1 => 14x14x8\n",
    "        self.conv_shortcut_1 = torch.nn.Conv2d(in_channels=1,\n",
    "                                               out_channels=8,\n",
    "                                               kernel_size=(1, 1),\n",
    "                                               stride=(2, 2),\n",
    "                                               padding=0)   \n",
    "        self.conv_shortcut_1_bn = torch.nn.BatchNorm2d(8)\n",
    "        \n",
    "        #########################\n",
    "        ### 2nd residual block\n",
    "        #########################\n",
    "        # 14x14x8 => 7x7x16 \n",
    "        self.conv_3 = torch.nn.Conv2d(in_channels=8,\n",
    "                                      out_channels=16,\n",
    "                                      kernel_size=(3, 3),\n",
    "                                      stride=(2, 2),\n",
    "                                      padding=1)\n",
    "        self.conv_3_bn = torch.nn.BatchNorm2d(16)\n",
    "                                    \n",
    "        # 7x7x16 => 7x7x32\n",
    "        self.conv_4 = torch.nn.Conv2d(in_channels=16,\n",
    "                                      out_channels=32,\n",
    "                                      kernel_size=(1, 1),\n",
    "                                      stride=(1, 1),\n",
    "                                      padding=0)   \n",
    "        self.conv_4_bn = torch.nn.BatchNorm2d(32)\n",
    "        \n",
    "        # 14x14x8 => 7x7x32 \n",
    "        self.conv_shortcut_2 = torch.nn.Conv2d(in_channels=8,\n",
    "                                               out_channels=32,\n",
    "                                               kernel_size=(1, 1),\n",
    "                                               stride=(2, 2),\n",
    "                                               padding=0)   \n",
    "        self.conv_shortcut_2_bn = torch.nn.BatchNorm2d(32)\n",
    "\n",
    "        #########################\n",
    "        ### Fully connected\n",
    "        #########################        \n",
    "        self.linear_1 = torch.nn.Linear(7*7*32, num_classes)\n",
    "\n",
    "        \n",
    "    def forward(self, x):\n",
    "        \n",
    "        #########################\n",
    "        ### 1st residual block\n",
    "        #########################\n",
    "        shortcut = x\n",
    "        \n",
    "        out = self.conv_1(x) # 28x28x1 => 14x14x4 \n",
    "        out = self.conv_1_bn(out)\n",
    "        out = F.relu(out)\n",
    "\n",
    "        out = self.conv_2(out) # 14x14x4 => 14x14x8\n",
    "        out = self.conv_2_bn(out)\n",
    "        \n",
    "        # match up dimensions using a linear function (no relu)\n",
    "        shortcut = self.conv_shortcut_1(shortcut)\n",
    "        shortcut = self.conv_shortcut_1_bn(shortcut)\n",
    "        \n",
    "        out += shortcut\n",
    "        out = F.relu(out)\n",
    "        \n",
    "        #########################\n",
    "        ### 2nd residual block\n",
    "        #########################\n",
    "        \n",
    "        shortcut = out\n",
    "        \n",
    "        out = self.conv_3(out) # 14x14x8 => 7x7x16 \n",
    "        out = self.conv_3_bn(out)\n",
    "        out = F.relu(out)\n",
    "\n",
    "        out = self.conv_4(out) # 7x7x16 => 7x7x32\n",
    "        out = self.conv_4_bn(out)\n",
    "        \n",
    "        # match up dimensions using a linear function (no relu)\n",
    "        shortcut = self.conv_shortcut_2(shortcut)\n",
    "        shortcut = self.conv_shortcut_2_bn(shortcut)\n",
    "        \n",
    "        out += shortcut\n",
    "        out = F.relu(out)\n",
    "        \n",
    "        #########################\n",
    "        ### Fully connected\n",
    "        #########################   \n",
    "        logits = self.linear_1(out.view(-1, 7*7*32))\n",
    "        probas = F.softmax(logits, dim=1)\n",
    "        return logits, probas\n",
    "\n",
    "    \n",
    "torch.manual_seed(random_seed)\n",
    "model = ConvNet(num_classes=num_classes)\n",
    "model = model.to(device)\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 001/010 | Batch 000/469 | Cost: 2.3318\n",
      "Epoch: 001/010 | Batch 050/469 | Cost: 0.1491\n",
      "Epoch: 001/010 | Batch 100/469 | Cost: 0.2615\n",
      "Epoch: 001/010 | Batch 150/469 | Cost: 0.0847\n",
      "Epoch: 001/010 | Batch 200/469 | Cost: 0.1427\n",
      "Epoch: 001/010 | Batch 250/469 | Cost: 0.1739\n",
      "Epoch: 001/010 | Batch 300/469 | Cost: 0.1558\n",
      "Epoch: 001/010 | Batch 350/469 | Cost: 0.0684\n",
      "Epoch: 001/010 | Batch 400/469 | Cost: 0.0717\n",
      "Epoch: 001/010 | Batch 450/469 | Cost: 0.0785\n",
      "Epoch: 001/010 training accuracy: 97.90%\n",
      "Epoch: 002/010 | Batch 000/469 | Cost: 0.0582\n",
      "Epoch: 002/010 | Batch 050/469 | Cost: 0.1199\n",
      "Epoch: 002/010 | Batch 100/469 | Cost: 0.0918\n",
      "Epoch: 002/010 | Batch 150/469 | Cost: 0.0247\n",
      "Epoch: 002/010 | Batch 200/469 | Cost: 0.0314\n",
      "Epoch: 002/010 | Batch 250/469 | Cost: 0.0759\n",
      "Epoch: 002/010 | Batch 300/469 | Cost: 0.0280\n",
      "Epoch: 002/010 | Batch 350/469 | Cost: 0.0391\n",
      "Epoch: 002/010 | Batch 400/469 | Cost: 0.0431\n",
      "Epoch: 002/010 | Batch 450/469 | Cost: 0.0455\n",
      "Epoch: 002/010 training accuracy: 98.16%\n",
      "Epoch: 003/010 | Batch 000/469 | Cost: 0.0303\n",
      "Epoch: 003/010 | Batch 050/469 | Cost: 0.0433\n",
      "Epoch: 003/010 | Batch 100/469 | Cost: 0.0465\n",
      "Epoch: 003/010 | Batch 150/469 | Cost: 0.0243\n",
      "Epoch: 003/010 | Batch 200/469 | Cost: 0.0258\n",
      "Epoch: 003/010 | Batch 250/469 | Cost: 0.0403\n",
      "Epoch: 003/010 | Batch 300/469 | Cost: 0.1024\n",
      "Epoch: 003/010 | Batch 350/469 | Cost: 0.0241\n",
      "Epoch: 003/010 | Batch 400/469 | Cost: 0.0299\n",
      "Epoch: 003/010 | Batch 450/469 | Cost: 0.0354\n",
      "Epoch: 003/010 training accuracy: 98.08%\n",
      "Epoch: 004/010 | Batch 000/469 | Cost: 0.0471\n",
      "Epoch: 004/010 | Batch 050/469 | Cost: 0.0954\n",
      "Epoch: 004/010 | Batch 100/469 | Cost: 0.0073\n",
      "Epoch: 004/010 | Batch 150/469 | Cost: 0.0531\n",
      "Epoch: 004/010 | Batch 200/469 | Cost: 0.0493\n",
      "Epoch: 004/010 | Batch 250/469 | Cost: 0.1070\n",
      "Epoch: 004/010 | Batch 300/469 | Cost: 0.0205\n",
      "Epoch: 004/010 | Batch 350/469 | Cost: 0.0270\n",
      "Epoch: 004/010 | Batch 400/469 | Cost: 0.0817\n",
      "Epoch: 004/010 | Batch 450/469 | Cost: 0.0182\n",
      "Epoch: 004/010 training accuracy: 98.70%\n",
      "Epoch: 005/010 | Batch 000/469 | Cost: 0.0691\n",
      "Epoch: 005/010 | Batch 050/469 | Cost: 0.0326\n",
      "Epoch: 005/010 | Batch 100/469 | Cost: 0.0041\n",
      "Epoch: 005/010 | Batch 150/469 | Cost: 0.0774\n",
      "Epoch: 005/010 | Batch 200/469 | Cost: 0.1223\n",
      "Epoch: 005/010 | Batch 250/469 | Cost: 0.0329\n",
      "Epoch: 005/010 | Batch 300/469 | Cost: 0.0479\n",
      "Epoch: 005/010 | Batch 350/469 | Cost: 0.0696\n",
      "Epoch: 005/010 | Batch 400/469 | Cost: 0.0504\n",
      "Epoch: 005/010 | Batch 450/469 | Cost: 0.0736\n",
      "Epoch: 005/010 training accuracy: 98.38%\n",
      "Epoch: 006/010 | Batch 000/469 | Cost: 0.0318\n",
      "Epoch: 006/010 | Batch 050/469 | Cost: 0.0303\n",
      "Epoch: 006/010 | Batch 100/469 | Cost: 0.0267\n",
      "Epoch: 006/010 | Batch 150/469 | Cost: 0.0912\n",
      "Epoch: 006/010 | Batch 200/469 | Cost: 0.0131\n",
      "Epoch: 006/010 | Batch 250/469 | Cost: 0.0164\n",
      "Epoch: 006/010 | Batch 300/469 | Cost: 0.0109\n",
      "Epoch: 006/010 | Batch 350/469 | Cost: 0.0699\n",
      "Epoch: 006/010 | Batch 400/469 | Cost: 0.0030\n",
      "Epoch: 006/010 | Batch 450/469 | Cost: 0.0237\n",
      "Epoch: 006/010 training accuracy: 98.74%\n",
      "Epoch: 007/010 | Batch 000/469 | Cost: 0.0214\n",
      "Epoch: 007/010 | Batch 050/469 | Cost: 0.0097\n",
      "Epoch: 007/010 | Batch 100/469 | Cost: 0.0292\n",
      "Epoch: 007/010 | Batch 150/469 | Cost: 0.0648\n",
      "Epoch: 007/010 | Batch 200/469 | Cost: 0.0044\n",
      "Epoch: 007/010 | Batch 250/469 | Cost: 0.0557\n",
      "Epoch: 007/010 | Batch 300/469 | Cost: 0.0139\n",
      "Epoch: 007/010 | Batch 350/469 | Cost: 0.0809\n",
      "Epoch: 007/010 | Batch 400/469 | Cost: 0.0285\n",
      "Epoch: 007/010 | Batch 450/469 | Cost: 0.0050\n",
      "Epoch: 007/010 training accuracy: 98.82%\n",
      "Epoch: 008/010 | Batch 000/469 | Cost: 0.0890\n",
      "Epoch: 008/010 | Batch 050/469 | Cost: 0.0685\n",
      "Epoch: 008/010 | Batch 100/469 | Cost: 0.0274\n",
      "Epoch: 008/010 | Batch 150/469 | Cost: 0.0187\n",
      "Epoch: 008/010 | Batch 200/469 | Cost: 0.0268\n",
      "Epoch: 008/010 | Batch 250/469 | Cost: 0.1681\n",
      "Epoch: 008/010 | Batch 300/469 | Cost: 0.0167\n",
      "Epoch: 008/010 | Batch 350/469 | Cost: 0.0518\n",
      "Epoch: 008/010 | Batch 400/469 | Cost: 0.0138\n",
      "Epoch: 008/010 | Batch 450/469 | Cost: 0.0270\n",
      "Epoch: 008/010 training accuracy: 99.08%\n",
      "Epoch: 009/010 | Batch 000/469 | Cost: 0.0458\n",
      "Epoch: 009/010 | Batch 050/469 | Cost: 0.0039\n",
      "Epoch: 009/010 | Batch 100/469 | Cost: 0.0597\n",
      "Epoch: 009/010 | Batch 150/469 | Cost: 0.0120\n",
      "Epoch: 009/010 | Batch 200/469 | Cost: 0.0580\n",
      "Epoch: 009/010 | Batch 250/469 | Cost: 0.0280\n",
      "Epoch: 009/010 | Batch 300/469 | Cost: 0.0570\n",
      "Epoch: 009/010 | Batch 350/469 | Cost: 0.0831\n",
      "Epoch: 009/010 | Batch 400/469 | Cost: 0.0732\n",
      "Epoch: 009/010 | Batch 450/469 | Cost: 0.0327\n",
      "Epoch: 009/010 training accuracy: 99.05%\n",
      "Epoch: 010/010 | Batch 000/469 | Cost: 0.0312\n",
      "Epoch: 010/010 | Batch 050/469 | Cost: 0.0130\n",
      "Epoch: 010/010 | Batch 100/469 | Cost: 0.0052\n",
      "Epoch: 010/010 | Batch 150/469 | Cost: 0.0188\n",
      "Epoch: 010/010 | Batch 200/469 | Cost: 0.0362\n",
      "Epoch: 010/010 | Batch 250/469 | Cost: 0.1085\n",
      "Epoch: 010/010 | Batch 300/469 | Cost: 0.0004\n",
      "Epoch: 010/010 | Batch 350/469 | Cost: 0.0299\n",
      "Epoch: 010/010 | Batch 400/469 | Cost: 0.0769\n",
      "Epoch: 010/010 | Batch 450/469 | Cost: 0.0247\n",
      "Epoch: 010/010 training accuracy: 98.87%\n"
     ]
    }
   ],
   "source": [
    "def compute_accuracy(model, data_loader):\n",
    "    correct_pred, num_examples = 0, 0\n",
    "    for i, (features, targets) in enumerate(data_loader):            \n",
    "        features = features.to(device)\n",
    "        targets = targets.to(device)\n",
    "        logits, probas = model(features)\n",
    "        _, predicted_labels = torch.max(probas, 1)\n",
    "        num_examples += targets.size(0)\n",
    "        correct_pred += (predicted_labels == targets).sum()\n",
    "    return correct_pred.float()/num_examples * 100\n",
    "    \n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    model = model.train()\n",
    "    for batch_idx, (features, targets) in enumerate(train_loader):\n",
    "        \n",
    "        features = features.to(device)\n",
    "        targets = targets.to(device)\n",
    "            \n",
    "        ### FORWARD AND BACK PROP\n",
    "        logits, probas = model(features)\n",
    "        cost = F.cross_entropy(logits, targets)\n",
    "        optimizer.zero_grad()\n",
    "        \n",
    "        cost.backward()\n",
    "        \n",
    "        ### UPDATE MODEL PARAMETERS\n",
    "        optimizer.step()\n",
    "        \n",
    "        ### LOGGING\n",
    "        if not batch_idx % 50:\n",
    "            print ('Epoch: %03d/%03d | Batch %03d/%03d | Cost: %.4f' \n",
    "                   %(epoch+1, num_epochs, batch_idx, \n",
    "                     len(train_loader), cost))\n",
    "\n",
    "    model = model.eval()  # eval mode: BatchNorm uses its running statistics and does not update them\n",
    "    with torch.set_grad_enabled(False): # save memory during inference\n",
    "        print('Epoch: %03d/%03d training accuracy: %.2f%%' % (\n",
    "              epoch+1, num_epochs, \n",
    "              compute_accuracy(model, train_loader)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test accuracy: 97.91%\n"
     ]
    }
   ],
   "source": [
    "with torch.set_grad_enabled(False):  # no autograd graph needed for evaluation\n",
    "    print('Test accuracy: %.2f%%' % (compute_accuracy(model, test_loader)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ResNet with convolutional blocks for resizing (using a helper class)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is the same network as above, but each residual block is implemented with a reusable `ResidualBlock` helper class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ResidualBlock(torch.nn.Module):\n",
    "\n",
    "    def __init__(self, channels):\n",
    "        \n",
    "        super(ResidualBlock, self).__init__()\n",
    "        self.conv_1 = torch.nn.Conv2d(in_channels=channels[0],\n",
    "                                      out_channels=channels[1],\n",
    "                                      kernel_size=(3, 3),\n",
    "                                      stride=(2, 2),\n",
    "                                      padding=1)\n",
    "        self.conv_1_bn = torch.nn.BatchNorm2d(channels[1])\n",
    "                                    \n",
    "        self.conv_2 = torch.nn.Conv2d(in_channels=channels[1],\n",
    "                                      out_channels=channels[2],\n",
    "                                      kernel_size=(1, 1),\n",
    "                                      stride=(1, 1),\n",
    "                                      padding=0)   \n",
    "        self.conv_2_bn = torch.nn.BatchNorm2d(channels[2])\n",
    "\n",
    "        self.conv_shortcut_1 = torch.nn.Conv2d(in_channels=channels[0],\n",
    "                                               out_channels=channels[2],\n",
    "                                               kernel_size=(1, 1),\n",
    "                                               stride=(2, 2),\n",
    "                                               padding=0)   \n",
    "        self.conv_shortcut_1_bn = torch.nn.BatchNorm2d(channels[2])\n",
    "\n",
    "    def forward(self, x):\n",
    "        shortcut = x\n",
    "        \n",
    "        out = self.conv_1(x)\n",
    "        out = self.conv_1_bn(out)\n",
    "        out = F.relu(out)\n",
    "\n",
    "        out = self.conv_2(out)\n",
    "        out = self.conv_2_bn(out)\n",
    "        \n",
    "        # match up dimensions with a linear projection (1x1 conv with stride 2, then BatchNorm; no ReLU)\n",
    "        shortcut = self.conv_shortcut_1(shortcut)\n",
    "        shortcut = self.conv_shortcut_1_bn(shortcut)\n",
    "        \n",
    "        out += shortcut\n",
    "        out = F.relu(out)\n",
    "\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "##########################\n",
    "### MODEL\n",
    "##########################\n",
    "\n",
    "\n",
    "\n",
    "class ConvNet(torch.nn.Module):\n",
    "\n",
    "    def __init__(self, num_classes):\n",
    "        super(ConvNet, self).__init__()\n",
    "        \n",
    "        self.residual_block_1 = ResidualBlock(channels=[1, 4, 8])\n",
    "        self.residual_block_2 = ResidualBlock(channels=[8, 16, 32])\n",
    "    \n",
    "        self.linear_1 = torch.nn.Linear(7*7*32, num_classes)\n",
    "\n",
    "        \n",
    "    def forward(self, x):\n",
    "\n",
    "        out = self.residual_block_1(x)\n",
    "        out = self.residual_block_2(out)\n",
    "         \n",
    "        logits = self.linear_1(out.view(-1, 7*7*32))\n",
    "        probas = F.softmax(logits, dim=1)\n",
    "        return logits, probas\n",
    "\n",
    "    \n",
    "torch.manual_seed(random_seed)\n",
    "model = ConvNet(num_classes=num_classes)\n",
    "\n",
    "model.to(device)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)  "
   ]
  },
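  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `7*7*32` input size of `linear_1` follows from the Conv2d output-size formula $\\lfloor (H + 2p - k)/s \\rfloor + 1$ (input size $H$, kernel size $k$, stride $s$, padding $p$). The cell below is a minimal sketch that applies this formula to the layers of `ResidualBlock`. It assumes 28x28 single-channel inputs (MNIST-sized images). It also checks that the main path and the 1x1 shortcut produce the same spatial size, which `out += shortcut` requires. The helper names `conv_out_size` and `residual_block_out_size` are illustrative and not part of PyTorch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def conv_out_size(size, kernel_size, stride, padding):\n",
    "    # spatial output size of a Conv2d layer along one dimension (dilation=1)\n",
    "    return (size + 2 * padding - kernel_size) // stride + 1\n",
    "\n",
    "\n",
    "def residual_block_out_size(size):\n",
    "    # height/width after one ResidualBlock as defined above\n",
    "    main = conv_out_size(size, kernel_size=3, stride=2, padding=1)      # conv_1\n",
    "    main = conv_out_size(main, kernel_size=1, stride=1, padding=0)      # conv_2\n",
    "    shortcut = conv_out_size(size, kernel_size=1, stride=2, padding=0)  # conv_shortcut_1\n",
    "    # `out += shortcut` only works if both paths have the same spatial size\n",
    "    assert main == shortcut, (main, shortcut)\n",
    "    return main\n",
    "\n",
    "\n",
    "size = 28                             # assumed input height/width (MNIST-sized images)\n",
    "size = residual_block_out_size(size)  # after residual_block_1: 14x14, 8 channels\n",
    "size = residual_block_out_size(size)  # after residual_block_2: 7x7, 32 channels\n",
    "print(size, size * size * 32)         # 32 = out_channels of residual_block_2"
   ]
  },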
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 001/010 | Batch 000/468 | Cost: 2.3318\n",
      "Epoch: 001/010 | Batch 050/468 | Cost: 0.1491\n",
      "Epoch: 001/010 | Batch 100/468 | Cost: 0.2615\n",
      "Epoch: 001/010 | Batch 150/468 | Cost: 0.0847\n",
      "Epoch: 001/010 | Batch 200/468 | Cost: 0.1427\n",
      "Epoch: 001/010 | Batch 250/468 | Cost: 0.1739\n",
      "Epoch: 001/010 | Batch 300/468 | Cost: 0.1558\n",
      "Epoch: 001/010 | Batch 350/468 | Cost: 0.0684\n",
      "Epoch: 001/010 | Batch 400/468 | Cost: 0.0717\n",
      "Epoch: 001/010 | Batch 450/468 | Cost: 0.0785\n",
      "Epoch: 001/010 training accuracy: 97.90%\n",
      "Epoch: 002/010 | Batch 000/468 | Cost: 0.0582\n",
      "Epoch: 002/010 | Batch 050/468 | Cost: 0.1199\n",
      "Epoch: 002/010 | Batch 100/468 | Cost: 0.0918\n",
      "Epoch: 002/010 | Batch 150/468 | Cost: 0.0247\n",
      "Epoch: 002/010 | Batch 200/468 | Cost: 0.0314\n",
      "Epoch: 002/010 | Batch 250/468 | Cost: 0.0759\n",
      "Epoch: 002/010 | Batch 300/468 | Cost: 0.0280\n",
      "Epoch: 002/010 | Batch 350/468 | Cost: 0.0391\n",
      "Epoch: 002/010 | Batch 400/468 | Cost: 0.0431\n",
      "Epoch: 002/010 | Batch 450/468 | Cost: 0.0455\n",
      "Epoch: 002/010 training accuracy: 98.16%\n",
      "Epoch: 003/010 | Batch 000/468 | Cost: 0.0303\n",
      "Epoch: 003/010 | Batch 050/468 | Cost: 0.0433\n",
      "Epoch: 003/010 | Batch 100/468 | Cost: 0.0465\n",
      "Epoch: 003/010 | Batch 150/468 | Cost: 0.0243\n",
      "Epoch: 003/010 | Batch 200/468 | Cost: 0.0258\n",
      "Epoch: 003/010 | Batch 250/468 | Cost: 0.0403\n",
      "Epoch: 003/010 | Batch 300/468 | Cost: 0.1024\n",
      "Epoch: 003/010 | Batch 350/468 | Cost: 0.0241\n",
      "Epoch: 003/010 | Batch 400/468 | Cost: 0.0299\n",
      "Epoch: 003/010 | Batch 450/468 | Cost: 0.0354\n",
      "Epoch: 003/010 training accuracy: 98.08%\n",
      "Epoch: 004/010 | Batch 000/468 | Cost: 0.0471\n",
      "Epoch: 004/010 | Batch 050/468 | Cost: 0.0954\n",
      "Epoch: 004/010 | Batch 100/468 | Cost: 0.0073\n",
      "Epoch: 004/010 | Batch 150/468 | Cost: 0.0531\n",
      "Epoch: 004/010 | Batch 200/468 | Cost: 0.0493\n",
      "Epoch: 004/010 | Batch 250/468 | Cost: 0.1070\n",
      "Epoch: 004/010 | Batch 300/468 | Cost: 0.0205\n",
      "Epoch: 004/010 | Batch 350/468 | Cost: 0.0270\n",
      "Epoch: 004/010 | Batch 400/468 | Cost: 0.0817\n",
      "Epoch: 004/010 | Batch 450/468 | Cost: 0.0182\n",
      "Epoch: 004/010 training accuracy: 98.70%\n",
      "Epoch: 005/010 | Batch 000/468 | Cost: 0.0691\n",
      "Epoch: 005/010 | Batch 050/468 | Cost: 0.0326\n",
      "Epoch: 005/010 | Batch 100/468 | Cost: 0.0041\n",
      "Epoch: 005/010 | Batch 150/468 | Cost: 0.0774\n",
      "Epoch: 005/010 | Batch 200/468 | Cost: 0.1223\n",
      "Epoch: 005/010 | Batch 250/468 | Cost: 0.0329\n",
      "Epoch: 005/010 | Batch 300/468 | Cost: 0.0479\n",
      "Epoch: 005/010 | Batch 350/468 | Cost: 0.0696\n",
      "Epoch: 005/010 | Batch 400/468 | Cost: 0.0504\n",
      "Epoch: 005/010 | Batch 450/468 | Cost: 0.0736\n",
      "Epoch: 005/010 training accuracy: 98.38%\n",
      "Epoch: 006/010 | Batch 000/468 | Cost: 0.0318\n",
      "Epoch: 006/010 | Batch 050/468 | Cost: 0.0303\n",
      "Epoch: 006/010 | Batch 100/468 | Cost: 0.0267\n",
      "Epoch: 006/010 | Batch 150/468 | Cost: 0.0912\n",
      "Epoch: 006/010 | Batch 200/468 | Cost: 0.0131\n",
      "Epoch: 006/010 | Batch 250/468 | Cost: 0.0164\n",
      "Epoch: 006/010 | Batch 300/468 | Cost: 0.0109\n",
      "Epoch: 006/010 | Batch 350/468 | Cost: 0.0699\n",
      "Epoch: 006/010 | Batch 400/468 | Cost: 0.0030\n",
      "Epoch: 006/010 | Batch 450/468 | Cost: 0.0237\n",
      "Epoch: 006/010 training accuracy: 98.74%\n",
      "Epoch: 007/010 | Batch 000/468 | Cost: 0.0214\n",
      "Epoch: 007/010 | Batch 050/468 | Cost: 0.0097\n",
      "Epoch: 007/010 | Batch 100/468 | Cost: 0.0292\n",
      "Epoch: 007/010 | Batch 150/468 | Cost: 0.0648\n",
      "Epoch: 007/010 | Batch 200/468 | Cost: 0.0044\n",
      "Epoch: 007/010 | Batch 250/468 | Cost: 0.0557\n",
      "Epoch: 007/010 | Batch 300/468 | Cost: 0.0139\n",
      "Epoch: 007/010 | Batch 350/468 | Cost: 0.0809\n",
      "Epoch: 007/010 | Batch 400/468 | Cost: 0.0285\n",
      "Epoch: 007/010 | Batch 450/468 | Cost: 0.0050\n",
      "Epoch: 007/010 training accuracy: 98.82%\n",
      "Epoch: 008/010 | Batch 000/468 | Cost: 0.0890\n",
      "Epoch: 008/010 | Batch 050/468 | Cost: 0.0685\n",
      "Epoch: 008/010 | Batch 100/468 | Cost: 0.0274\n",
      "Epoch: 008/010 | Batch 150/468 | Cost: 0.0187\n",
      "Epoch: 008/010 | Batch 200/468 | Cost: 0.0268\n",
      "Epoch: 008/010 | Batch 250/468 | Cost: 0.1681\n",
      "Epoch: 008/010 | Batch 300/468 | Cost: 0.0167\n",
      "Epoch: 008/010 | Batch 350/468 | Cost: 0.0518\n",
      "Epoch: 008/010 | Batch 400/468 | Cost: 0.0138\n",
      "Epoch: 008/010 | Batch 450/468 | Cost: 0.0270\n",
      "Epoch: 008/010 training accuracy: 99.08%\n",
      "Epoch: 009/010 | Batch 000/468 | Cost: 0.0458\n",
      "Epoch: 009/010 | Batch 050/468 | Cost: 0.0039\n",
      "Epoch: 009/010 | Batch 100/468 | Cost: 0.0597\n",
      "Epoch: 009/010 | Batch 150/468 | Cost: 0.0120\n",
      "Epoch: 009/010 | Batch 200/468 | Cost: 0.0580\n",
      "Epoch: 009/010 | Batch 250/468 | Cost: 0.0280\n",
      "Epoch: 009/010 | Batch 300/468 | Cost: 0.0570\n",
      "Epoch: 009/010 | Batch 350/468 | Cost: 0.0831\n",
      "Epoch: 009/010 | Batch 400/468 | Cost: 0.0732\n",
      "Epoch: 009/010 | Batch 450/468 | Cost: 0.0327\n",
      "Epoch: 009/010 training accuracy: 99.05%\n",
      "Epoch: 010/010 | Batch 000/468 | Cost: 0.0312\n",
      "Epoch: 010/010 | Batch 050/468 | Cost: 0.0130\n",
      "Epoch: 010/010 | Batch 100/468 | Cost: 0.0052\n",
      "Epoch: 010/010 | Batch 150/468 | Cost: 0.0188\n",
      "Epoch: 010/010 | Batch 200/468 | Cost: 0.0362\n",
      "Epoch: 010/010 | Batch 250/468 | Cost: 0.1085\n",
      "Epoch: 010/010 | Batch 300/468 | Cost: 0.0004\n",
      "Epoch: 010/010 | Batch 350/468 | Cost: 0.0299\n",
      "Epoch: 010/010 | Batch 400/468 | Cost: 0.0769\n",
      "Epoch: 010/010 | Batch 450/468 | Cost: 0.0247\n",
      "Epoch: 010/010 training accuracy: 98.87%\n"
     ]
    }
   ],
   "source": [
    "def compute_accuracy(model, data_loader):\n",
    "    correct_pred, num_examples = 0, 0\n",
    "    for features, targets in data_loader:\n",
    "        features = features.to(device)\n",
    "        targets = targets.to(device)\n",
    "        logits, probas = model(features)\n",
    "        _, predicted_labels = torch.max(probas, 1)\n",
    "        num_examples += targets.size(0)\n",
    "        correct_pred += (predicted_labels == targets).sum()\n",
    "    return correct_pred.float()/num_examples * 100\n",
    "    \n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    model = model.train()\n",
    "    for batch_idx, (features, targets) in enumerate(train_loader):\n",
    "        \n",
    "        features = features.to(device)\n",
    "        targets = targets.to(device)\n",
    "            \n",
    "        ### FORWARD AND BACK PROP\n",
    "        logits, probas = model(features)\n",
    "        cost = F.cross_entropy(logits, targets)\n",
    "        optimizer.zero_grad()\n",
    "        \n",
    "        cost.backward()\n",
    "        \n",
    "        ### UPDATE MODEL PARAMETERS\n",
    "        optimizer.step()\n",
    "        \n",
    "        ### LOGGING\n",
    "        if not batch_idx % 50:\n",
    "            print ('Epoch: %03d/%03d | Batch %03d/%03d | Cost: %.4f' \n",
    "                   %(epoch+1, num_epochs, batch_idx, \n",
    "                     len(train_dataset)//batch_size, cost))\n",
    "\n",
    "    model = model.eval()  # eval mode: BatchNorm uses its running statistics and does not update them\n",
    "    with torch.set_grad_enabled(False): # save memory during inference\n",
    "        print('Epoch: %03d/%03d training accuracy: %.2f%%' % (\n",
    "              epoch+1, num_epochs, \n",
    "              compute_accuracy(model, train_loader)))"
   ]
  },
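  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The batch counters in the two training logs differ (`/469` in the first run, `/468` in the second). The first loop reports `len(train_loader)`, while the second reports `len(train_dataset)//batch_size`. The cell below is a minimal sketch that reproduces both numbers. It assumes the 60,000-image MNIST training set and a batch size of 128, the values consistent with these counts. The `DataLoader` default `drop_last=False` still yields the final, smaller batch, so `len(train_loader)` rounds up while the integer division rounds down. The variable names are illustrative and do not overwrite the notebook's own `batch_size`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "num_train_examples = 60000  # assumed: size of the MNIST training set\n",
    "assumed_batch_size = 128    # assumed: consistent with the logged batch counts\n",
    "\n",
    "# len(train_dataset)//batch_size counts only the full batches\n",
    "full_batches = num_train_examples // assumed_batch_size\n",
    "# len(train_loader) also counts the final partial batch (drop_last=False)\n",
    "all_batches = math.ceil(num_train_examples / assumed_batch_size)\n",
    "last_batch_size = num_train_examples - full_batches * assumed_batch_size\n",
    "\n",
    "print(full_batches, all_batches, last_batch_size)"
   ]
  },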
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test accuracy: 97.91%\n"
     ]
    }
   ],
   "source": [
    "with torch.set_grad_enabled(False):  # no autograd graph needed for evaluation\n",
    "    print('Test accuracy: %.2f%%' % (compute_accuracy(model, test_loader)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "numpy       1.15.4\n",
      "torch       1.0.1.post2\n",
      "\n"
     ]
    }
   ],
   "source": [
    "%watermark -iv"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  },
  "toc": {
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
